Made masked losses compatible with masked nans #18829
Conversation
Codecov Report

@@ Coverage Diff @@
## master #18829 +/- ##
==========================================
+ Coverage 79.30% 79.37% +0.06%
==========================================
Files 336 336
Lines 34549 34775 +226
Branches 6799 6841 +42
==========================================
+ Hits 27400 27603 +203
- Misses 5567 5590 +23
Partials 1582 1582
Thanks for the PR.
keras/losses/loss.py
Outdated
allowed = {"sum_over_batch_size", "sum", None, "none"}
raise ValueError(
    "Invalid value for argument `reduction`. "
    f"Expected on of {allowed}. Received: "
Typo: "Expected on of" should be "Expected one of".
keras/metrics/reduction_metrics.py
Outdated
@@ -13,8 +13,11 @@ def reduce_to_samplewise_values(values, sample_weight, reduce_fn, dtype):
     if sample_weight is not None:
         sample_weight = ops.cast(sample_weight, dtype=dtype)
         if mask is not None:
-            sample_weight = losses.loss.apply_mask(
-                sample_weight, mask, dtype=dtype, reduction="sum"
+            sample_weight, mask = losses.loss.squeeze_to_same_rank(
Why not refactor `apply_mask` instead?
batch_size = ops.count_nonzero(mask)
values_sum = ops.sum(values)
# safe divide
return ops.cond(
This will add significant overhead to the computation. Is there a better way that doesn't involve a cond?
There are plenty of ways to do essentially this, but is a `cond` on scalars that expensive? What about a `where`? They seemed the clearest, but other options would be:
- use `batch_size = maximum(batch_size, 1)` - should give identical results (batch size is an integer, and when it's 0 the numerator should be zero)
- use `batch_size = batch_size + epsilon` - closer to the current implementation, though IMO wrong.
@fchollet refactored to use a re-introduced
The reason has to do with parallelization. Conditional branches are much harder to handle than a single serial op even if it has more flops. @haifeng-jin can you advise here on how to test the performance impact of this change for a couple of standard models on GPU? I'd like to compare the
@jackd here is what I did for benchmarking a PR. Each of the notebooks is for one of the backends, either before or after the change. You will see the perf at the end of the notebook.
Thanks, Haifeng! @jackd can you use the same code to benchmark the impact of this change?
This isn't a priority for me, and I've already spent a lot longer on this than I intended to. If anyone else wants to take this up, feel free; otherwise, would a modified PR with just the simplified version (re-replacing masking with multiplication by zero) be accepted?
Running on colab T4
@dugujiujian1999 maybe I'm missing something, but why do runs with lower total times have higher times per step? e.g.
@jackd I don't know. It takes less time after the patch. I used the code here:
@jackd, can you please rebase the code to follow the latest code structure, like
CBF at this point, someone else can take over if they'd like.
NaNs are a great way to ensure certain values aren't used (e.g. those associated with masked positions). This change ensures that masked values are correctly masked (set to zero, even when nan) rather than multiplied by zero (which leaves nans as nans).
This PR also (IMO) greatly simplifies the masking / weighting loss implementation. Test coverage is also improved.
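A minimal NumPy sketch of the failure mode the description refers to (illustrative only; the PR itself operates on Keras `ops` tensors, not NumPy arrays):

```python
import numpy as np

values = np.array([1.0, 2.0, np.nan])  # nan marks a value that must never be used
mask = np.array([True, True, False])   # the masked position holds the nan

# Multiplying by the mask does NOT remove the nan: nan * 0 is still nan,
# so the nan propagates into any downstream sum or mean.
masked_by_multiply = values * mask

# where-style masking genuinely replaces the masked entry with zero.
masked_by_where = np.where(mask, values, 0.0)
```

With multiplication, any reduction over `masked_by_multiply` (a sum, a mean) comes out nan; with `where`, the masked nan is replaced before the reduction and the result is finite.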